from pointops import knn_query
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from Utils.Registry import Registry
from Utils.Tool import index_points

# Name -> class registries for the configurable pieces of this module. Classes
# below register themselves via ``@<registry>.register_module(...)`` and are
# instantiated with ``<registry>.build(cfg)``; the config's ``NAME`` entry
# presumably selects the registered class (see the configs in ``__main__``) —
# confirm against Utils.Registry.
point_focused = Registry('PointFocusedAttention')  # full attention module and its "without..." variants
loc_pos = Registry('LocalityPositionEncoding')  # position bias for the local (k-NN) branch
glob_pos = Registry('GlobalPositionEncoding')  # position bias for the global (inducing) branch
loc_perc = Registry('LocalityPerception')  # local k-NN attention branch
glob_indu = Registry('GlobalInducing')  # global inducing-point branch


class KNN_Query:
    """Batched k-nearest-neighbour search backed by ``pointops.knn_query``.

    ``knn_query`` works on point clouds flattened into a single
    ``(total_points, dims)`` array, with samples delimited by cumulative
    ``offset`` vectors. This wrapper flattens dense batched tensors, runs the
    query without autograd, and maps the returned flat indices back to indices
    local to each sample.
    """

    def __init__(self, neighbours):
        # Number of neighbours returned for every query point.
        self.neighbours = neighbours

    def __call__(self, xyz, new_xyz):
        """Find the ``neighbours`` nearest points of ``xyz`` for each point of ``new_xyz``.

        Args:
            xyz: reference points, ``(batch, n_ref, dims)``.
            new_xyz: query points, ``(batch, n_query, dims)``; same batch size
                and ``dims`` as ``xyz``.

        Returns:
            int64 tensor ``(batch, n_query, neighbours)`` whose entries index the
            second axis of ``xyz`` within the same sample.
        """
        with torch.no_grad():
            batch, n_ref, dims = xyz.shape
            n_query = new_xyz.shape[1]
            device = new_xyz.device

            # Dense inputs: every sample holds the same number of points, so the
            # cumulative end offsets are plain multiples (no Python loop and no
            # host->device copy).
            sample_ends = torch.arange(1, batch + 1, device=device, dtype=torch.int32)
            offset = sample_ends * n_ref
            new_offset = sample_ends * n_query

            idx, _ = knn_query(
                self.neighbours,
                xyz.reshape(-1, dims).contiguous(),
                offset,
                new_xyz.reshape(-1, dims).contiguous(),
                new_offset,
            )

            # ``idx`` indexes the flattened *reference* cloud, so every sample is
            # shifted back by the start of its reference points (b * n_ref).
            # Shifting by the query offsets is only correct when n_query == n_ref.
            ref_starts = torch.arange(batch, device=device, dtype=torch.long) * n_ref
            return idx.reshape(batch, n_query, self.neighbours).long() - ref_starts[:, None, None]


@loc_pos.register_module('withLocPos')
class LocalityPosition(nn.Module):
    """Relative-position bias for the local (k-NN) attention logits.

    Each point-to-neighbour offset ``p_i - p_j`` (3-D, as required by the first
    ``Linear(3, channels)``) is embedded by a two-layer MLP ``phi``. The bias
    ``q_i . phi(p_i - p_j) + k_j . phi(p_i - p_j)`` is then added to the
    incoming logits.
    """

    def __init__(self, channels):
        super().__init__()
        self.pos_emb = nn.Sequential(
            nn.Linear(3, channels),
            nn.ReLU(inplace=True),
            nn.Linear(channels, channels),
        )

    def forward(self, attn_wights, coordinates, q, k, reference_index):
        """Return ``attn_wights`` plus the query- and key-dependent position bias.

        NOTE(review): shapes presumably are coordinates (B, N, 3), q/k (B, N, C),
        reference_index (B, N, K) and attn_wights (B, N, 1, K) — confirm
        against ``index_points``.
        """
        neighbour_xyz = index_points(coordinates, reference_index)
        rel_emb = self.pos_emb(coordinates.unsqueeze(2) - neighbour_xyz)
        # query term: one dot product per neighbour, laid out like the logits
        query_term = q.unsqueeze(2) @ rel_emb.transpose(-2, -1)
        # key term: neighbour keys dotted with the same embeddings, moved to the logit layout
        key_dots = (index_points(k, reference_index) * rel_emb).sum(dim=-1, keepdim=True)
        key_term = key_dots.transpose(-2, -1)
        return attn_wights + (query_term + key_term)


@glob_pos.register_module('withGlobPos')
class GlobalPosition(nn.Module):
    """Learnable position bias for the global (inducing-point) attention logits.

    Holds one learnable ``channels``-dim embedding per inducing point (``T`` of
    them) and adds ``q @ T^T`` to the query->inducing logits.
    """

    def __init__(self, channels, T):
        super(GlobalPosition, self).__init__()
        self.channels = channels
        # (T, channels) embedding table, one row per inducing point.
        self.T = nn.Parameter(torch.zeros((T, self.channels)), requires_grad=True)
        # ``std`` must be passed by keyword: the second positional parameter of
        # trunc_normal_ is ``mean``, so ``trunc_normal_(t, 0.02)`` would give
        # mean 0.02 / std 1.0 instead of the intended std 0.02.
        nn.init.trunc_normal_(self.T, std=0.02)

    def forward(self, attn_weights, q):
        """Add the bias to ``attn_weights``.

        Args:
            attn_weights: logits whose last axis has length ``T``, e.g. (batch, N, T).
            q: queries whose last axis has length ``channels``, e.g. (batch, N, channels).

        Returns:
            ``attn_weights + q @ T^T``, same shape as ``attn_weights``.
        """
        # (..., N, C) @ (C, T) broadcasts over leading dims; no explicit expand needed.
        return attn_weights + q @ self.T.mT


@glob_pos.register_module('withoutGlobPos')
@loc_pos.register_module('withoutLocPos')
class withoutLocalityPosition(nn.Module):
    """Identity position encoding, used when a positional bias is disabled.

    Any constructor keyword arguments are accepted and ignored; ``forward``
    returns its first positional input (the attention logits) unchanged.
    """

    def __init__(self, **_ignored):
        super().__init__()

    def forward(self, *inputs):
        # Pass the logits straight through; the remaining inputs are ignored.
        passthrough = inputs[0]
        return passthrough


@loc_perc.register_module()
class LocalityPerception(nn.Module):
    """Local attention logits between every point and its k nearest neighbours.

    Features are projected to q/k/v by one linear layer. Neighbours are found
    with ``KNN_Query`` on the coordinates, the raw logits ``q_i . k_j`` are
    computed, and the configured position encoding (built from ``loc_pos_cfgs``)
    is applied to them. The logits are returned unnormalised, together with the
    gathered neighbour values.
    """

    def __init__(self, channels, neighbours, loc_pos_cfgs, bias=True):
        super().__init__()
        self.qkv = nn.Linear(channels, channels * 3, bias=bias)
        self.knn = KNN_Query(neighbours)
        # The position encoder is built with this module's width (mutates the config).
        loc_pos_cfgs.channels = channels
        self.pos_enc = loc_pos.build(loc_pos_cfgs)

    def forward(self, coordinates, features):
        """Return ``(logits, neighbour_values)`` for ``features`` of shape (B, N, C)."""
        batch, _, width = features.shape
        projected = self.qkv(features).reshape(batch, -1, 3, width)
        q, k, v = projected.permute(2, 0, 1, 3)
        neighbours = self.knn(coordinates, coordinates)
        # one row of K logits per point: (B, N, 1, C) @ (B, N, C, K)
        logits = q.unsqueeze(2) @ index_points(k, neighbours).transpose(-2, -1)
        logits = self.pos_enc(logits, coordinates, q, k, neighbours)
        return logits, index_points(v, neighbours)


@glob_indu.register_module()
class GlobalInducing(nn.Module):
    """Global attention branch routed through a set of learnable inducing points.

    The inducing points first attend over all input points (softmax over points,
    scaled by 1/sqrt(channels)) to summarise the cloud. The summary is projected
    to keys/values that every input point can attend to. ``forward`` returns the
    *unnormalised* point->inducing logits (with the configured global position
    bias added) together with the inducing values; the final softmax is left to
    the caller.

    Args:
        inducing: number of inducing points (an int; used as a tensor dimension).
        channels: feature width.
        glob_pos_cfgs: position-encoding config (mutated: ``channels`` and ``T`` are set).
        bias: whether the linear projections have a bias.
    """

    def __init__(self, inducing, channels, glob_pos_cfgs, bias=True):
        super(GlobalInducing, self).__init__()
        self.channels = channels

        # (inducing, channels) learnable queries. ``std`` must be passed by keyword:
        # the second positional parameter of trunc_normal_ is ``mean``.
        self.inducing = nn.Parameter(torch.zeros((inducing, channels)), requires_grad=True)
        nn.init.trunc_normal_(self.inducing, std=0.02)

        self.qkv = nn.Linear(channels, 3 * channels, bias=bias)
        self.softmax = nn.Softmax(dim=-1)
        self.inducing_kv = nn.Linear(channels, 2 * channels, bias=bias)

        glob_pos_cfgs.channels = channels
        glob_pos_cfgs.T = inducing
        self.glob_enc = glob_pos.build(glob_pos_cfgs)

    def forward(self, features):
        """Return ``(logits, inducing_v)`` for ``features`` of shape (B, N, C).

        ``logits`` is (B, N, inducing) and ``inducing_v`` is (B, inducing, C).
        """
        batch, _, channels = features.shape
        q, k, v = self.qkv(features).reshape(batch, -1, 3, channels).permute(2, 0, 1, 3)

        # Inducing points attend over all points: (T, C) @ (B, C, N) broadcasts to (B, T, N).
        summary_weights = self.softmax(self.inducing @ k.mT / self.channels ** 0.5)
        attn_inducing = summary_weights @ v  # (B, T, C)

        inducing_k, inducing_v = self.inducing_kv(attn_inducing).reshape(batch, -1, 2, channels).permute(2, 0, 1, 3)
        attn_weights = q @ inducing_k.mT  # (B, N, T), not yet normalised
        attn_weights = self.glob_enc(attn_weights, q)
        return attn_weights, inducing_v


@point_focused.register_module()
class PointFocusedAttention(nn.Module):
    """Joint attention over local neighbours and global inducing points.

    Every point attends to its k nearest neighbours (``loc_perc`` branch) and to
    the inducing points of the ``glob_indu`` branch. Both sets of logits are
    concatenated, scaled by 1/sqrt(channels) and normalised by a single softmax.
    The branch values are then mixed with those weights and passed through
    ``out_proj``.

    Args:
        channels: feature width, propagated to both branch configs.
        loc_perc_cfgs: config of the local branch (mutated: ``channels`` is set).
        glob_indu_cfgs: config of the global branch (mutated: ``channels`` is set).
        bias: whether ``out_proj`` has a bias.
    """

    def __init__(self, channels, loc_perc_cfgs, glob_indu_cfgs, bias=True):
        super(PointFocusedAttention, self).__init__()
        self.channels = channels

        loc_perc_cfgs.channels = channels
        glob_indu_cfgs.channels = channels
        self.loc_perc = loc_perc.build(loc_perc_cfgs)
        self.glob_indu = glob_indu.build(glob_indu_cfgs)

        self.softmax = nn.Softmax(dim=-1)
        self.out_proj = nn.Linear(channels, channels, bias=bias)

    def forward(self, coordinates, features):
        """Return attended features of shape (B, N, channels)."""
        # Expected shapes with the LocalityPerception / GlobalInducing branches:
        #   loc_attn_weights (B, N, 1, K), v_loc (B, N, K, C)
        #   glob_attn_weights (B, N, T),   v_glob (B, T, C)
        loc_attn_weights, v_loc = self.loc_perc(coordinates, features)
        glob_attn_weights, v_glob = self.glob_indu(features)

        neighbours = loc_attn_weights.shape[-1]
        logits = torch.cat((loc_attn_weights, glob_attn_weights[:, :, None]), dim=-1) / self.channels ** 0.5
        # One joint distribution over K neighbours + T inducing points per point: (B, N, 1, K + T).
        attn_weights = self.softmax(logits)

        # Mix each branch's values separately instead of concatenating them.
        # Concatenation would need v_glob repeated for every point, i.e. a
        # (B, N, T, C) tensor; the split sum below is mathematically identical.
        loc_mix = (attn_weights[..., :neighbours] @ v_loc).squeeze(-2)    # (B, N, C)
        glob_mix = attn_weights[..., neighbours:].squeeze(-2) @ v_glob    # (B, N, C)
        return self.out_proj(loc_mix + glob_mix)


@point_focused.register_module("withoutMultiscale")
class withoutMultiscale(nn.Module):
    """Variant registered as 'withoutMultiscale': separate softmaxes per branch.

    The local (k-NN) and global (inducing-point) logits are each scaled by
    1/sqrt(channels) and normalised independently. The two mixed value vectors
    are then summed and projected by ``out_proj``.
    """

    def __init__(self, channels, loc_perc_cfgs, glob_indu_cfgs, bias=True):
        super(withoutMultiscale, self).__init__()
        self.channels = channels

        # Give both branches this module's width, as PointFocusedAttention does;
        # otherwise the branches only work if the configs already carry a
        # matching ``channels``.
        loc_perc_cfgs.channels = channels
        glob_indu_cfgs.channels = channels
        self.loc_perc = loc_perc.build(loc_perc_cfgs)
        self.glob_indu = glob_indu.build(glob_indu_cfgs)

        self.out_proj = nn.Linear(channels, channels, bias=bias)

    def forward(self, coordinates, features):
        """Return attended features of shape (B, N, channels)."""
        loc_attn_weights, v_loc = self.loc_perc(coordinates, features)
        glob_attn_weights, v_glob = self.glob_indu(features)

        scale = self.channels ** 0.5

        loc_attn_weights = F.softmax(loc_attn_weights / scale, dim=-1)
        glob_attn_weights = F.softmax(glob_attn_weights / scale, dim=-1)
        # local mix is (B, N, 1, C) -> squeeze to (B, N, C) to match the global mix
        attn = self.out_proj((loc_attn_weights @ v_loc).squeeze(dim=-2) + glob_attn_weights @ v_glob)
        return attn


@point_focused.register_module("withoutFocusedAttn")
class withoutFocusedAttention(nn.Module):
    """Variant registered as 'withoutFocusedAttn': global inducing branch only.

    ``coordinates`` and ``loc_perc_cfgs`` are accepted for interface
    compatibility with the other point-focused variants but are not used.
    """

    def __init__(self, channels, loc_perc_cfgs, glob_indu_cfgs, bias=True):
        super(withoutFocusedAttention, self).__init__()
        self.channels = channels

        # Give the global branch this module's width, as the other variants do.
        # ``loc_perc_cfgs`` is unused here and deliberately left untouched
        # (callers may pass None for it).
        glob_indu_cfgs.channels = channels
        self.glob_indu = glob_indu.build(glob_indu_cfgs)

        self.out_proj = nn.Linear(channels, channels, bias=bias)

    def forward(self, coordinates, features):
        """Return attended features of shape (B, N, channels)."""
        glob_attn_weights, v_glob = self.glob_indu(features)

        glob_attn_weights = F.softmax(glob_attn_weights / self.channels ** 0.5, dim=-1)
        attn = self.out_proj(glob_attn_weights @ v_glob)
        return attn


@point_focused.register_module("withoutGlobalIndu")
class withoutGlobalInducing(nn.Module):
    """Variant registered as 'withoutGlobalIndu': local k-NN branch only.

    Both configs receive ``channels`` (the global one is mutated but never
    built). The local logits are scaled by 1/sqrt(channels), softmax-normalised
    over the neighbours, used to mix the neighbour values, and projected.
    """

    def __init__(self, channels, loc_perc_cfgs, glob_indu_cfgs, bias=True):
        super().__init__()
        self.channels = channels

        for cfg in (loc_perc_cfgs, glob_indu_cfgs):
            cfg.channels = channels
        self.loc_perc = loc_perc.build(loc_perc_cfgs)
        self.out_proj = nn.Linear(channels, channels, bias=bias)

    def forward(self, coordinates, features):
        """Return attended features of shape (B, N, channels)."""
        logits, neighbour_values = self.loc_perc(coordinates, features)
        temperature = torch.sqrt(torch.tensor(self.channels, dtype=torch.float32))
        weights = F.softmax(logits / temperature, dim=-1)
        mixed = torch.squeeze(weights @ neighbour_values, dim=-2)
        return self.out_proj(mixed)


if __name__ == '__main__':
    # Smoke test: build PointFocusedAttention from nested configs and run one forward pass.
    # NOTE(review): pointops' knn_query appears to be a CUDA op, so the CPU
    # fallback is unlikely to get past LocalityPerception — confirm.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    class AttrDict(dict):
        """dict whose keys are also readable/writable as attributes.

        The modules above assign e.g. ``loc_perc_cfgs.channels = ...``, which a
        plain dict does not support.
        NOTE(review): assumes Utils.Registry.build accepts an EasyDict-style
        config — confirm.
        """

        def __getattr__(self, name):
            try:
                return self[name]
            except KeyError as err:
                raise AttributeError(name) from err

        def __setattr__(self, name, value):
            self[name] = value

    def to_attr_dict(value):
        """Recursively convert nested dicts into AttrDicts."""
        if isinstance(value, dict):
            return AttrDict({key: to_attr_dict(item) for key, item in value.items()})
        return value

    cfgs = to_attr_dict({
        'channels': 48,
        'loc_perc_cfgs': {
            'NAME': 'LocalityPerception',
            'channels': 48,
            'neighbours': 3,
            'loc_pos_cfgs': {
                'NAME': 'withLocPos',
                'channels': 48
            },
        },
        'glob_indu_cfgs': {
            'NAME': 'GlobalInducing',
            # Number of inducing points: must be an int, it is used as a tensor dimension.
            'inducing': 128,
            'channels': 48,
            'glob_pos_cfgs': {
                'NAME': 'withGlobPos'
            },
        },
    })

    batch, N, channels = 8, 1024, 48
    coordinates = torch.rand((batch, N, 3), device=device)
    features = torch.rand((batch, N, channels), device=device)

    focused_attn = PointFocusedAttention(**cfgs)
    focused_attn.to(device)

    with torch.no_grad():
        attn = focused_attn(coordinates, features)
    print(attn.shape)
